{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 304,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Copyright (c) Facebook, Inc. and its affiliates.\n",
    "\"\"\"\n",
    "\n",
    "from collections import Counter, defaultdict\n",
    "import argparse\n",
    "import ast\n",
    "import fileinput\n",
    "import json\n",
    "import os\n",
    "\n",
    "right_answer_count = Counter()\n",
    "wrong_answer_count = Counter()\n",
    "\n",
    "# compile sets of allowed answers\n",
    "allowed_answers = defaultdict(set)\n",
    "command = None\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 305,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_gold_set(gold_set):\n",
    "    with open(gold_set, \"r\") as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if line == \"\":\n",
    "                continue\n",
    "            if line.startswith(\"{\"):\n",
    "                try:\n",
    "                    allowed_answers[command].add(line)\n",
    "                except:\n",
    "                    print(\"Bad allowed answer:\", line)\n",
    "                    raise\n",
    "            else:\n",
    "                command = line\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 307,
   "metadata": {},
   "outputs": [],
   "source": [
    "read_gold_set('data/qual_gold_answers.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 308,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_dicts(action_dict, allowed_dict):\n",
    "    # action_dict = ast.literal_eval(action_dict)\n",
    "    allowed_dict = ast.literal_eval(allowed_dict)\n",
    "    if \"repeat\" in allowed_dict:\n",
    "        if \"repeat\" not in action_dict:\n",
    "\n",
    "            return False\n",
    "        val = allowed_dict[\"repeat\"]\n",
    "        val2 = action_dict[\"repeat\"]\n",
    "        if val != val2:\n",
    "            if val[0] != val2[0]:\n",
    "                return False\n",
    "            val_dict1 = val[1]\n",
    "            val_dict2 = val2[1]\n",
    "            for k, v in val_dict2.items():\n",
    "                if k == \"repeat_dir\":\n",
    "                    continue\n",
    "                if k not in val_dict1 or v != val_dict1[k]:\n",
    "                    return False\n",
    "\n",
    "    for k, v in allowed_dict.items():\n",
    "        if k == \"repeat\":\n",
    "            continue\n",
    "        if k not in action_dict or action_dict[k] != v:\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 309,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_wrong_stats(dict1, dict2_list, sentence):\n",
    "    \"\"\" {'repeat',\n",
    "        'schematic',\n",
    "        'dialogue_type',\n",
    "        'action_type',\n",
    "        'has_block_type',\n",
    "        'reference_object',\n",
    "        'tag_val',\n",
    "        'filters',\n",
    "        'location',\n",
    "        'target_action_type'}\"\"\"\n",
    "    st = {}\n",
    "    for d in dict2_list: # ground truth\n",
    "        dict2 = ast.literal_eval(d)\n",
    "        for k, v in dict2.items():\n",
    "            if k not in dict1:\n",
    "                    st['missing_key_'+k] = st.get('missing_key_'+k, 0)+1 \n",
    "                    return st\n",
    "            if v != dict1[k]:\n",
    "                if k =='action_type' and v[1] != dict1[k][1]:\n",
    "                    if sentence == \"dig two small holes behind the pool\" and dict1[k][1] in ['build', 'dig']:\n",
    "                        continue\n",
    "                    st[\"action_type_diff\"] = st.get(\"action_type_diff\", \"\")+ \"_\"+ dict1[k][1]\n",
    "                    st[k+\"_value_wrong\"] = st.get(k+\"_value_wrong\", 0)+1\n",
    "                    return st\n",
    "                st[k+\"_span_wrong\"] = st.get(k+\"_span_wrong\", 0)+1\n",
    "                return st\n",
    "    return st"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 310,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_workers(worker_file):\n",
    "    worker_stats = {}\n",
    "    wrong_stats = {}\n",
    "    for k, v in allowed_answers.items():\n",
    "        wrong_stats[k] = {}\n",
    "    \n",
    "    with open(worker_file) as f:\n",
    "        # one worker at a time\n",
    "        for line in f.readlines():\n",
    "            right_count = 0\n",
    "            wrong_count = 0\n",
    "            worker_id, answers = line.strip().split(\"\\t\")\n",
    "            answer_dicts = ast.literal_eval(answers) # all three answers with -- sentence: dict\n",
    "\n",
    "            # if worker didn't answer all questions, ignore\n",
    "            if len(answer_dicts.keys()) < 3:\n",
    "                print(\"Skipping: %r completed only %r\" % (worker_id, len(answer_dicts.keys())))\n",
    "                continue\n",
    "\n",
    "            # otherwise read all answers\n",
    "            # k is sentence, v is dict\n",
    "            for k, v in answer_dicts.items():\n",
    "                # sentence has to be in allowed_answers, unnecessary check\n",
    "                if k not in allowed_answers:\n",
    "                    print(\"The sentence: %r is missing.\" %(sentence))\n",
    "                \n",
    "                # if answer doesn't match any allowed answer\n",
    "                if not any(compare_dicts(v, x) for x in allowed_answers[k]):\n",
    "                    wrong_count += 1\n",
    "                    # Ananlyze the mistake\n",
    "                    stats = get_wrong_stats(v, allowed_answers[k], k)\n",
    "                    # stats = get_wrong_stats(v, d)\n",
    "                    for a, b in stats.items():\n",
    "                        if a not in wrong_stats[k]:\n",
    "                            wrong_stats[k][a] = b\n",
    "                        elif type(b)==int:\n",
    "                            wrong_stats[k][a] += b\n",
    "                        elif type(b)==str:\n",
    "                            wrong_stats[k][a] += \"_\" + b\n",
    "                else:\n",
    "                    right_count += 1\n",
    "            # print(\"-\" * 30)\n",
    "            worker_stats[worker_id] = int((right_count / (right_count + wrong_count)) * 100)\n",
    "\n",
    "    return worker_stats, wrong_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'A3M3ZFDVYMBJ1X': 66} {'fill all the holes with water': {'action_type_diff': '_destroy', 'action_type_value_wrong': 1}, 'dig two small holes behind the pool': {}, 'go to the red cube between the trees': {}}\n"
     ]
    }
   ],
   "source": [
    "# worker_stats, wrong_stats = evaluate_workers('data/qual_test_answers/2nd_500_qual_user_answers.txt')\n",
    "worker_stats, wrong_stats = evaluate_workers('/Users/kavyasrinet/Downloads/test_q.txt')\n",
    "print(worker_stats, wrong_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 302,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fill all the holes with water\n",
      "other_action_type_values {'move', 'action', 'noop', 'stop', 'otheraction', 'tag', 'undo', 'resume', 'destroy', 'build', 'freebuild', 'answer', 'dance', 'spawn', 'dig', 'composite', 'copy'}\n",
      "action_type_value_wrong 127\n",
      "missing_key_has_block_type 125\n",
      "missing_key_repeat 59\n",
      "reference_object_span_wrong 55\n",
      "has_block_type_span_wrong 33\n",
      "repeat_span_wrong 28\n",
      "missing_key_action_type 3\n",
      "missing_key_reference_object 1\n",
      "Total mistakes: 431\n",
      "********************\n",
      "dig two small holes behind the pool\n",
      "other_action_type_values {'fill', 'move', 'action', 'resume', 'stop', 'otheraction', 'tag', 'undo', 'destroy', 'freebuild', 'answer', 'dance', 'spawn', 'composite', 'copy'}\n",
      "schematic_span_wrong 153\n",
      "action_type_value_wrong 130\n",
      "location_span_wrong 64\n",
      "missing_key_location 44\n",
      "missing_key_repeat 10\n",
      "repeat_span_wrong 2\n",
      "missing_key_action_type 1\n",
      "Total mistakes: 404\n",
      "********************\n",
      "go to the red cube between the trees\n",
      "other_action_type_values {'fill', 'undo', 'action', 'resume', 'stop', 'otheraction', 'tag', 'noop', 'destroy', 'freebuild', 'build', 'answer', 'dance', 'spawn', 'dig', 'composite', 'copy'}\n",
      "location_span_wrong 188\n",
      "action_type_value_wrong 126\n",
      "missing_key_action_type 3\n",
      "Total mistakes: 317\n",
      "********************\n"
     ]
    }
   ],
   "source": [
    "import operator\n",
    "from pprint import pprint\n",
    "for k, v in wrong_stats.items():\n",
    "    print(k)\n",
    "    total = 0\n",
    "    if \"action_type_diff\" in v:\n",
    "        vals = v[\"action_type_diff\"].split(\"_\")\n",
    "        vs = set(vals)\n",
    "        vs.remove('')\n",
    "        print(\"other_action_type_values\", vs)\n",
    "        v.pop(\"action_type_diff\")\n",
    "    sorted_d = dict(sorted(v.items(), key=operator.itemgetter(1),reverse=True))\n",
    "    \n",
    "    for a, b in sorted_d.items():\n",
    "        print(a, b)\n",
    "        if type(b)==int:\n",
    "            total += int(b)\n",
    "    print(\"Total mistakes: %r\"%(total))\n",
    "    print(\"*\"*20)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 303,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "with open('data/qual_test_workers/second_500_workers.txt', 'w') as f:\n",
    "    for k, v in worker_stats.items():\n",
    "        f.write(k +\"\\t\" + str(v) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key : absent\n",
    "    key: tp, fp, tn, fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
